K-Nearest Neighbors (KNN) is a supervised machine learning algorithm that predicts which class a data point belongs to. Given labeled training data, the algorithm finds the K training points closest to a new data point; the class that occurs most often among those K neighbors is taken as the predicted class. We use labeled pumpkin seed data to build a model and predict the classes of the held-out testing data, using the knn() function from R's class package. The workflow covers feature scaling, a training/testing split, and an evaluation of the model's accuracy. Scatterplots are used for visualization.
Original Dataset: https://www.kaggle.com/datasets/muratkokludataset/pumpkin-seeds-dataset
# Install tidyverse only when it is missing. Re-installing an already
# installed (and possibly loaded) package on every knit is slow and can fail
# with "cannot remove prior installation". Download from CRAN over HTTPS.
if (!requireNamespace("tidyverse", quietly = TRUE)) {
  install.packages("tidyverse", repos = "https://cloud.r-project.org")
}
## Installing package into 'C:/Users/Steve/AppData/Local/R/win-library/4.2'
## (as 'lib' is unspecified)
## package 'tidyverse' successfully unpacked and MD5 sums checked
##
## The downloaded binary packages are in
## C:\Users\Steve\AppData\Local\Temp\RtmpQTTRME\downloaded_packages
# Install ggplot2 only when it is missing; download from CRAN over HTTPS.
if (!requireNamespace("ggplot2", quietly = TRUE)) {
  install.packages("ggplot2", repos = "https://cloud.r-project.org")
}
## Installing package into 'C:/Users/Steve/AppData/Local/R/win-library/4.2'
## (as 'lib' is unspecified)
## package 'ggplot2' successfully unpacked and MD5 sums checked
##
## The downloaded binary packages are in
## C:\Users\Steve\AppData\Local\Temp\RtmpQTTRME\downloaded_packages
# Install class (provides knn()) only when it is missing. Unconditionally
# re-installing it while its DLL is in use fails with "Permission denied".
if (!requireNamespace("class", quietly = TRUE)) {
  install.packages("class", repos = "https://cloud.r-project.org")
}
## Installing package into 'C:/Users/Steve/AppData/Local/R/win-library/4.2'
## (as 'lib' is unspecified)
## package 'class' successfully unpacked and MD5 sums checked
## Warning: cannot remove prior installation of package 'class'
## Warning in file.copy(savedcopy, lib, recursive = TRUE): problem copying C:
## \Users\Steve\AppData\Local\R\win-library\4.2\00LOCK\class\libs\x64\class.dll
## to C:\Users\Steve\AppData\Local\R\win-library\4.2\class\libs\x64\class.dll:
## Permission denied
## Warning: restored 'class'
##
## The downloaded binary packages are in
## C:\Users\Steve\AppData\Local\Temp\RtmpQTTRME\downloaded_packages
# Install readxl (provides read_xlsx()) only when it is missing, avoiding the
# "cannot remove prior installation" failure on re-install.
if (!requireNamespace("readxl", quietly = TRUE)) {
  install.packages("readxl", repos = "https://cloud.r-project.org")
}
## Installing package into 'C:/Users/Steve/AppData/Local/R/win-library/4.2'
## (as 'lib' is unspecified)
## package 'readxl' successfully unpacked and MD5 sums checked
## Warning: cannot remove prior installation of package 'readxl'
## Warning in file.copy(savedcopy, lib, recursive = TRUE): problem copying C:
## \Users\Steve\AppData\Local\R\win-library\4.2\00LOCK\readxl\libs\x64\readxl.dll
## to C:\Users\Steve\AppData\Local\R\win-library\4.2\readxl\libs\x64\readxl.dll:
## Permission denied
## Warning: restored 'readxl'
##
## The downloaded binary packages are in
## C:\Users\Steve\AppData\Local\Temp\RtmpQTTRME\downloaded_packages
# Install caret (provides confusionMatrix()) only when it is missing,
# avoiding the "cannot remove prior installation" failure on re-install.
if (!requireNamespace("caret", quietly = TRUE)) {
  install.packages("caret", repos = "https://cloud.r-project.org")
}
## Installing package into 'C:/Users/Steve/AppData/Local/R/win-library/4.2'
## (as 'lib' is unspecified)
## package 'caret' successfully unpacked and MD5 sums checked
## Warning: cannot remove prior installation of package 'caret'
## Warning in file.copy(savedcopy, lib, recursive = TRUE): problem copying C:
## \Users\Steve\AppData\Local\R\win-library\4.2\00LOCK\caret\libs\x64\caret.dll
## to C:\Users\Steve\AppData\Local\R\win-library\4.2\caret\libs\x64\caret.dll:
## Permission denied
## Warning: restored 'caret'
##
## The downloaded binary packages are in
## C:\Users\Steve\AppData\Local\Temp\RtmpQTTRME\downloaded_packages
# Attach the tidyverse core packages (including dplyr for select()/%>% and
# ggplot2 for the plots below).
library(tidyverse)
## ── Attaching packages
## ───────────────────────────────────────
## tidyverse 1.3.2 ──
## ✔ ggplot2 3.4.0 ✔ purrr 0.3.5
## ✔ tibble 3.1.8 ✔ dplyr 1.0.10
## ✔ tidyr 1.2.1 ✔ stringr 1.4.1
## ✔ readr 2.1.3 ✔ forcats 0.5.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
# Attach the remaining packages in order: ggplot2 (plots), class (knn()),
# readxl (read_xlsx()) and caret (confusionMatrix()). Attaching caret last
# means its lift() masks purrr's.
invisible(lapply(
  c("ggplot2", "class", "readxl", "caret"),
  library,
  character.only = TRUE
))
## Loading required package: lattice
##
## Attaching package: 'caret'
##
## The following object is masked from 'package:purrr':
##
## lift
# Read the seed measurements from the Excel workbook (relative path, so the
# file is expected in the working directory) and inspect the column types.
pumpkin_data <- read_xlsx(path = "Pumpkin_Seeds_Dataset.xlsx")
str(object = pumpkin_data)
## tibble [2,500 × 13] (S3: tbl_df/tbl/data.frame)
## $ Area : num [1:2500] 56276 76631 71623 66458 66107 ...
## $ Perimeter : num [1:2500] 888 1068 1083 992 998 ...
## $ Major_Axis_Length: num [1:2500] 326 417 436 382 384 ...
## $ Minor_Axis_Length: num [1:2500] 220 234 211 223 220 ...
## $ Convex_Area : num [1:2500] 56831 77280 72663 67118 67117 ...
## $ Equiv_Diameter : num [1:2500] 268 312 302 291 290 ...
## $ Eccentricity : num [1:2500] 0.738 0.828 0.875 0.812 0.819 ...
## $ Solidity : num [1:2500] 0.99 0.992 0.986 0.99 0.985 ...
## $ Extent : num [1:2500] 0.745 0.715 0.74 0.74 0.675 ...
## $ Roundness : num [1:2500] 0.896 0.844 0.767 0.849 0.834 ...
## $ Aspect_Ration : num [1:2500] 1.48 1.78 2.07 1.71 1.74 ...
## $ Compactness : num [1:2500] 0.821 0.749 0.693 0.762 0.756 ...
## $ Class : chr [1:2500] "Çerçevelik" "Çerçevelik" "Çerçevelik" "Çerçevelik" ...
# Per-column summary statistics. The numeric measurements differ in scale by
# several orders of magnitude (e.g. Area vs. Solidity), which matters for a
# distance-based method like KNN.
summary(pumpkin_data)
## Area Perimeter Major_Axis_Length Minor_Axis_Length
## Min. : 47939 Min. : 868.5 Min. :320.8 Min. :152.2
## 1st Qu.: 70765 1st Qu.:1048.8 1st Qu.:415.0 1st Qu.:211.2
## Median : 79076 Median :1123.7 Median :449.5 Median :224.7
## Mean : 80658 Mean :1130.3 Mean :456.6 Mean :225.8
## 3rd Qu.: 89758 3rd Qu.:1203.3 3rd Qu.:492.7 3rd Qu.:240.7
## Max. :136574 Max. :1559.5 Max. :661.9 Max. :305.8
## Convex_Area Equiv_Diameter Eccentricity Solidity
## Min. : 48366 Min. :247.1 Min. :0.4921 Min. :0.9186
## 1st Qu.: 71512 1st Qu.:300.2 1st Qu.:0.8317 1st Qu.:0.9883
## Median : 79872 Median :317.3 Median :0.8637 Median :0.9903
## Mean : 81508 Mean :319.3 Mean :0.8609 Mean :0.9895
## 3rd Qu.: 90798 3rd Qu.:338.1 3rd Qu.:0.8970 3rd Qu.:0.9915
## Max. :138384 Max. :417.0 Max. :0.9481 Max. :0.9944
## Extent Roundness Aspect_Ration Compactness
## Min. :0.4680 Min. :0.5546 Min. :1.149 Min. :0.5608
## 1st Qu.:0.6589 1st Qu.:0.7519 1st Qu.:1.801 1st Qu.:0.6635
## Median :0.7130 Median :0.7977 Median :1.984 Median :0.7077
## Mean :0.6932 Mean :0.7915 Mean :2.042 Mean :0.7041
## 3rd Qu.:0.7402 3rd Qu.:0.8343 3rd Qu.:2.262 3rd Qu.:0.7435
## Max. :0.8296 Max. :0.9396 Max. :3.144 Max. :0.9049
## Class
## Length:2500
## Class :character
## Mode :character
##
##
##
# Preview the first rows of the raw data.
head(pumpkin_data)
## # A tibble: 6 × 13
## Area Perimeter Major…¹ Minor…² Conve…³ Equiv…⁴ Eccen…⁵ Solid…⁶ Extent Round…⁷
## <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 56276 888. 326. 220. 56831 268. 0.738 0.990 0.745 0.896
## 2 76631 1068. 417. 234. 77280 312. 0.828 0.992 0.715 0.844
## 3 71623 1083. 436. 211. 72663 302. 0.875 0.986 0.74 0.767
## 4 66458 992. 382. 223. 67118 291. 0.812 0.990 0.740 0.849
## 5 66107 998. 384. 220. 67117 290. 0.819 0.985 0.675 0.834
## 6 73191 1041. 406. 231. 73969 305. 0.822 0.990 0.716 0.848
## # … with 3 more variables: Aspect_Ration <dbl>, Compactness <dbl>, Class <chr>,
## # and abbreviated variable names ¹Major_Axis_Length, ²Minor_Axis_Length,
## # ³Convex_Area, ⁴Equiv_Diameter, ⁵Eccentricity, ⁶Solidity, ⁷Roundness
# Raw (unscaled) measurements: area against aspect ratio, with each class
# drawn in its own shape and colour.
pumpkin_data %>%
  ggplot() +
  geom_point(aes(x = Area, y = Aspect_Ration, shape = Class, colour = Class)) +
  labs(x = "Area", y = "Aspect Ratio") +
  ggtitle("Scatter Plot of Pumpkin Seeds") +
  theme(plot.title = element_text(hjust = 0.5))
# Numeric class code used as the model label:
# 1 = "Çerçevelik", 2 = any other Class value.
pumpkin_data <- pumpkin_data %>%
  mutate(Species = if_else(Class == "Çerçevelik", 1, 2))
str(pumpkin_data)
## tibble [2,500 × 14] (S3: tbl_df/tbl/data.frame)
## $ Area : num [1:2500] 56276 76631 71623 66458 66107 ...
## $ Perimeter : num [1:2500] 888 1068 1083 992 998 ...
## $ Major_Axis_Length: num [1:2500] 326 417 436 382 384 ...
## $ Minor_Axis_Length: num [1:2500] 220 234 211 223 220 ...
## $ Convex_Area : num [1:2500] 56831 77280 72663 67118 67117 ...
## $ Equiv_Diameter : num [1:2500] 268 312 302 291 290 ...
## $ Eccentricity : num [1:2500] 0.738 0.828 0.875 0.812 0.819 ...
## $ Solidity : num [1:2500] 0.99 0.992 0.986 0.99 0.985 ...
## $ Extent : num [1:2500] 0.745 0.715 0.74 0.74 0.675 ...
## $ Roundness : num [1:2500] 0.896 0.844 0.767 0.849 0.834 ...
## $ Aspect_Ration : num [1:2500] 1.48 1.78 2.07 1.71 1.74 ...
## $ Compactness : num [1:2500] 0.821 0.749 0.693 0.762 0.756 ...
## $ Class : chr [1:2500] "Çerçevelik" "Çerçevelik" "Çerçevelik" "Çerçevelik" ...
## $ Species : num [1:2500] 1 1 1 1 1 1 1 1 1 1 ...
# Model inputs: the 12 numeric measurements plus the numeric Species code
# (the character Class column is dropped; Species ends up as column 13).
Pumpkin_data1 <- pumpkin_data %>% select(-Class)
feature_cols <- 1:12

# Random ~80/20 train/test split; the seed makes the split reproducible.
set.seed(1234)
index <- sample(2, nrow(Pumpkin_data1), replace = TRUE, prob = c(0.8, 0.2))
training_data <- Pumpkin_data1[index == 1, ]
testing_data <- Pumpkin_data1[index == 2, ]

# Standardize AFTER splitting. The centre/scale are estimated on the training
# rows only and applied unchanged to the testing rows, so no information from
# the test set leaks into preprocessing. Scaling is needed because KNN is
# distance-based and large-valued features (e.g. Area) would otherwise dominate.
train_scaled <- scale(training_data[, feature_cols])
training_data[, feature_cols] <- train_scaled
testing_data[, feature_cols] <- scale(
  testing_data[, feature_cols],
  center = attr(train_scaled, "scaled:center"),
  scale = attr(train_scaled, "scaled:scale")
)

# Separate the label column (Species) from the feature columns.
training_label <- training_data %>% select(Species)
testing_label <- testing_data %>% select(Species)
training_data <- training_data %>% select(-Species)
testing_data <- testing_data %>% select(-Species)

# Rule-of-thumb neighbour count: square root of the training-set size.
K <- round(sqrt(nrow(training_data)))
print(K)
## [1] 45
# Classify each testing row by majority vote among its K nearest training
# rows (Euclidean distance on the scaled features). `cl` is given as a factor
# of training labels, one per training row, as knn() expects.
predictions <- knn(
  train = training_data,
  test = testing_data,
  cl = factor(training_label$Species),
  k = K
)

# caret's confusionMatrix() on a table treats ROWS as predictions and COLUMNS
# as the reference (true) labels. Building the table the other way round
# silently swaps sensitivity/PPV, specificity/NPV and prevalence. The
# reference labels share the prediction levels so the table is always square.
Accuracy <- table(
  knn_prediction = predictions,
  testing_labels = factor(testing_label$Species, levels = levels(predictions))
)
# Species code 1 ("Çerçevelik") is the first level; made explicit here.
confusionMatrix(Accuracy, positive = "1")
## Confusion Matrix and Statistics
##
## knn_prediction
## tesing_lables 1 2
## 1 248 15
## 2 35 195
##
## Accuracy : 0.8986
## 95% CI : (0.8685, 0.9238)
## No Information Rate : 0.574
## P-Value [Acc > NIR] : < 2e-16
##
## Kappa : 0.7951
##
## Mcnemar's Test P-Value : 0.00721
##
## Sensitivity : 0.8763
## Specificity : 0.9286
## Pos Pred Value : 0.9430
## Neg Pred Value : 0.8478
## Prevalence : 0.5740
## Detection Rate : 0.5030
## Detection Prevalence : 0.5335
## Balanced Accuracy : 0.9024
##
## 'Positive' Class : 1
##
# Plot the training rows, with shape marking the numeric class code. The
# feature columns were standardized before modelling, so the axes are in
# z-score units rather than the original measurement units.
df_training <- training_data
df_training$Class <- as.factor(training_label$Species)
ggplot(data = df_training) +
  geom_point(mapping = aes(x = Area, y = Aspect_Ration, shape = Class)) +
  labs(y = "Aspect Ratio (standardized)", x = "Area (standardized)") +
  ggtitle("Scatter Plot of Pumpkin Seed Training Data") +
  theme(plot.title = element_text(hjust = 0.5))
# Plot the testing rows: shape shows the true class code, colour shows
# whether the KNN prediction matched it. Axes are in standardized (z-score)
# units because the features were scaled before modelling.
df_testing <- testing_data
df_testing$Class <- testing_label$Species
df_testing$Prediction <- predictions
# Class is numeric here while knn() returns a factor whose levels come from
# the training labels. Comparing both as character avoids relying on
# implicit factor/numeric coercion in `==`.
df_testing$Accuracy <- ifelse(
  as.character(df_testing$Class) == as.character(df_testing$Prediction),
  "Correct",
  "Incorrect"
)
df_testing$Class <- as.factor(df_testing$Class)
ggplot(data = df_testing) +
  geom_point(
    mapping = aes(x = Area, y = Aspect_Ration, shape = Class, colour = Accuracy)
  ) +
  scale_color_manual(values = c("Correct" = "green", "Incorrect" = "black")) +
  labs(y = "Aspect Ratio (standardized)", x = "Area (standardized)") +
  ggtitle("Scatter Plot of Pumpkin Seed Testing Data using KNN") +
  theme(plot.title = element_text(hjust = 0.5))